import pandas as pd
import json
from tqdm import tqdm
from transformers import BertTokenizer, BertModel
import torch
import argparse
import pickle
import numpy as np 



# saved_dict = pickle.load(open("./data/kairos_ontology.pkl", "rb"))[0]
file_path = "../data/kairos_ontology.xlsx"

events_ontology = pd.read_excel(file_path, sheet_name="events")
entities_ontology = pd.read_excel(file_path, sheet_name="entities")

event_types_def = events_ontology['Definition'].tolist()
entity_types_def = entities_ontology['Definition'].tolist()


node_types_dict = {}
for ii in range(len(event_types_def)):
    node_types_dict[ii] = event_types_def[ii]
for ii in range(len(entity_types_def)):
    node_types_dict[ii + 67] = entity_types_def[ii]



# event_types = saved_dict['event_types']
# entity_types_dict = saved_dict["entity_types_dict"]
# node_type_dict = {}
# for key, val in event_types.items():
#     node_type_dict[key] = val
# for key, val in entity_types_dict.items():
#     node_type_dict[key] = val + 67
# print(node_type_dict)

device = torch.device("cuda:2")

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained("bert-base-uncased")

# tokenizer.to(device)
model.to(device)

event_types_str_lst = [node_types_dict[ii] for ii in range(len(node_types_dict))]
# print(event_types_str_lst)
# print(len(event_types_str_lst))
# raise

event_type_tensors = tokenizer(event_types_str_lst, padding=True, truncation=True, return_tensors="pt")


event_type_embeddings = None

for i, j in enumerate(tqdm(event_types_str_lst)):

    output = model(input_ids=event_type_tensors["input_ids"][i:i+1].to(device), 
        attention_mask=event_type_tensors['attention_mask'][i:i+1].to(device))['pooler_output']

    # print(output['last_hidden_state'])
    
    # print(output.shape)
    curr_out = np.array(output.squeeze().tolist()).reshape((1, -1))

    if event_type_embeddings is None:
        event_type_embeddings = curr_out
    else:
        event_type_embeddings = np.concatenate((event_type_embeddings, curr_out))

print(event_type_embeddings.shape)


with open("../data/kairos_ontology_embeddings.pkl", "wb") as f:
    pickle.dump(event_type_embeddings, f)




















